/*
 * Copyright (c) 2013 IBM Corp.
 *
 * All rights reserved. This program and the accompanying materials
 * are made available under the terms of the Eclipse Public License v1.0
 * which accompanies this distribution, and is available at
 * http://www.eclipse.org/legal/epl-v10.html
 *
 * Contributors:
 *    Seth Hoenig
 *    Allan Stockdill-Mander
 *    Mike Robertson
 */

package mqtt

import (
	"net"
	"reflect"
	"strings"
	"time"

	"github.com/eclipse/paho.mqtt.golang/packets"
)

// closedNetConnErrorText is the error string for a closed conn (https://golang.org/src/net/error_test.go).
// The comms routines in this file look for it as a substring of err.Error() so that errors caused by the
// connection having been closed elsewhere are not reported.
const closedNetConnErrorText = "use of closed network connection"

// ConnectMQTT takes a connected net.Conn and performs the initial MQTT handshake. Parameters are:
// conn - Connected net.Conn
// cm - Connect Packet with everything other than the protocolname/version populated (historical reasons)
// protocolVersion - The protocol version to attempt to connect with; 3, 0x83 and 0x84 are used as given,
// any other value selects MQTT 3.1.1 (protocol version 4)
//
// Returns the CONNACK return code and the session present flag. If the CONNECT packet cannot be written,
// or a CONNACK cannot be read, packets.ErrNetworkError and false are returned.
func ConnectMQTT(conn net.Conn, cm *packets.ConnectPacket, protocolVersion uint) (byte, bool) {
	switch protocolVersion {
	case 3:
		DEBUG.Println(CLI, "Using MQTT 3.1 protocol")
		cm.ProtocolName = "MQIsdp"
		cm.ProtocolVersion = 3
	case 0x83:
		DEBUG.Println(CLI, "Using MQTT 3.1b protocol")
		cm.ProtocolName = "MQIsdp"
		cm.ProtocolVersion = 0x83
	case 0x84:
		DEBUG.Println(CLI, "Using MQTT 3.1.1b protocol")
		cm.ProtocolName = "MQTT"
		cm.ProtocolVersion = 0x84
	default:
		DEBUG.Println(CLI, "Using MQTT 3.1.1 protocol")
		cm.ProtocolName = "MQTT"
		cm.ProtocolVersion = 4
	}
	if err := cm.Write(conn); err != nil {
		ERROR.Println(CLI, err)
		// The CONNECT packet was not sent, so there is no point waiting for a CONNACK in response;
		// report the failure straight away rather than attempting a read on the connection.
		return packets.ErrNetworkError, false
	}

	return verifyCONNACK(conn)
}

// verifyCONNACK reads one packet from conn and checks that it is a CONNACK.
// It is only used for receiving the connack when the connection is first started;
// this prevents receiving incoming data while resume is in progress if clean
// session is false.
// Returns the return code and session present flag from the CONNACK, or
// packets.ErrNetworkError and false if the read fails, yields no packet, or
// yields a packet that is not a CONNACK.
func verifyCONNACK(conn net.Conn) (byte, bool) {
	DEBUG.Println(NET, "connect started")

	pkt, readErr := packets.ReadPacket(conn)
	switch {
	case readErr != nil:
		ERROR.Println(NET, "connect got error", readErr)
		return packets.ErrNetworkError, false
	case pkt == nil:
		ERROR.Println(NET, "received nil packet")
		return packets.ErrNetworkError, false
	}

	if connack, isConnack := pkt.(*packets.ConnackPacket); isConnack {
		DEBUG.Println(NET, "received connack")
		return connack.ReturnCode, connack.SessionPresent
	}

	ERROR.Println(NET, "received msg that was not CONNACK")
	return packets.ErrNetworkError, false
}

// inbound encapsulates the output from startIncoming.
// err - If != nil then an error has occurred
// cp - A control packet received over the network link
type inbound struct {
	err error
	cp  packets.ControlPacket
}

// startIncoming launches a goroutine that reads packets from conn and delivers each one on the returned channel.
// When a read fails the goroutine passes the error on the channel (unless the error text indicates that the
// connection was closed), then closes the channel and exits; so closing the connection will terminate the goroutine.
func startIncoming(conn net.Conn) <-chan inbound {
	received := make(chan inbound)

	DEBUG.Println(NET, "incoming started")
	go func() {
		for {
			pkt, readErr := packets.ReadPacket(conn)
			if readErr == nil {
				DEBUG.Println(NET, "Received Message")
				received <- inbound{cp: pkt}
				continue
			}

			// We do not want to report the error if it is due to the network connection having been closed
			// elsewhere (i.e. after sending DisconnectPacket). Detecting this situation is the subject of
			// https://github.com/golang/go/issues/4373
			closedElsewhere := strings.Contains(readErr.Error(), closedNetConnErrorText)
			if !closedElsewhere {
				received <- inbound{err: readErr}
			}
			close(received)
			DEBUG.Println(NET, "incoming complete")
			return
		}
	}()
	return received
}

// incommingComms encapsulates the possible output of the startIncommingComms routine. If err != nil then an error has
// occurred; otherwise one of the other members should be non-nil
type incommingComms struct {
	err          error                  // If non-nil then there has been an error (ignore everything else)
	outbound     *PacketAndToken        // Packet (with token) that needs to be sent out (e.g. an acknowledgement)
	incommingPub *packets.PublishPacket // A new publish has been received; this will need to be passed on to our user
}

// startIncommingComms initiates incoming communications; this includes starting a goroutine to process incoming
// messages.
// Accepts a channel of inbound messages from the store (persisted messages); note this must be closed as soon as
// everything in the store has been sent.
// Returns a channel that will be passed errors, received publish packets and any packets that need to be sent in
// response to a received packet. It is closed once both the network reader channel and inboundFromStore have
// been closed.
func startIncommingComms(conn net.Conn,
	c commsFns,
	inboundFromStore <-chan packets.ControlPacket,
) <-chan incommingComms {
	ibound := startIncoming(conn) // Start goroutine that reads from network connection
	output := make(chan incommingComms)

	DEBUG.Println(NET, "startIncommingComms started")
	go func() {
		for {
			// Both inputs are set to nil once they have been closed; when neither remains there is nothing left to do
			if inboundFromStore == nil && ibound == nil {
				close(output)
				DEBUG.Println(NET, "startIncommingComms goroutine complete")
				return // (any network error should have already been passed on via output)
			}
			DEBUG.Println(NET, "logic waiting for msg on ibound")

			var msg packets.ControlPacket
			var ok bool
			select {
			case msg, ok = <-inboundFromStore:
				if !ok {
					DEBUG.Println(NET, "startIncommingComms: inboundFromStore complete")
					inboundFromStore = nil // should happen quickly as this is only for persisted messages
					continue
				}
				DEBUG.Println(NET, "startIncommingComms: got msg from store")
			case ibMsg, ok := <-ibound:
				if !ok {
					DEBUG.Println(NET, "startIncommingComms: ibound complete")
					ibound = nil
					continue
				}
				DEBUG.Println(NET, "startIncommingComms: got msg on ibound")
				// If the inbound comms routine encounters any issues it will send us an error.
				if ibMsg.err != nil {
					output <- incommingComms{err: ibMsg.err}
					continue // Usually the channel will be closed immediately after sending an error but safer that we do not assume this
				}
				msg = ibMsg.cp

				// Only packets read from the network are persisted (those from the store are already there)
				c.persistInbound(msg)
				c.UpdateLastReceived() // Notify keepalive logic that we recently received a packet
			}

			switch m := msg.(type) {
			case *packets.PingrespPacket:
				DEBUG.Println(NET, "received pingresp")
				c.pingRespReceived()
			case *packets.SubackPacket:
				DEBUG.Println(NET, "received suback, id:", m.MessageID)
				token := c.getToken(m.MessageID)
				switch t := token.(type) {
				case *SubscribeToken:
					DEBUG.Println(NET, "granted qoss", m.ReturnCodes)
					for i, qos := range m.ReturnCodes {
						// The return codes come from the broker; never index past the subscriptions we
						// actually requested (a malformed SUBACK must not be able to panic this goroutine).
						if i >= len(t.subs) {
							ERROR.Println(NET, "suback contained more return codes than subscriptions requested, id:", m.MessageID)
							break
						}
						t.subResult[t.subs[i]] = qos
					}
				}
				token.flowComplete()
				c.freeID(m.MessageID)
			case *packets.UnsubackPacket:
				DEBUG.Println(NET, "received unsuback, id:", m.MessageID)
				c.getToken(m.MessageID).flowComplete()
				c.freeID(m.MessageID)
			case *packets.PublishPacket:
				DEBUG.Println(NET, "received publish, msgId:", m.MessageID)
				output <- incommingComms{incommingPub: m}
			case *packets.PubackPacket:
				DEBUG.Println(NET, "received puback, id:", m.MessageID)
				c.getToken(m.MessageID).flowComplete()
				c.freeID(m.MessageID)
			case *packets.PubrecPacket:
				// A PUBREC is answered with a PUBREL carrying the same message id
				DEBUG.Println(NET, "received pubrec, id:", m.MessageID)
				prel := packets.NewControlPacket(packets.Pubrel).(*packets.PubrelPacket)
				prel.MessageID = m.MessageID
				output <- incommingComms{outbound: &PacketAndToken{p: prel, t: nil}}
			case *packets.PubrelPacket:
				// A PUBREL is answered with a PUBCOMP carrying the same message id (persisted before being passed on)
				DEBUG.Println(NET, "received pubrel, id:", m.MessageID)
				pc := packets.NewControlPacket(packets.Pubcomp).(*packets.PubcompPacket)
				pc.MessageID = m.MessageID
				c.persistOutbound(pc)
				output <- incommingComms{outbound: &PacketAndToken{p: pc, t: nil}}
			case *packets.PubcompPacket:
				DEBUG.Println(NET, "received pubcomp, id:", m.MessageID)
				c.getToken(m.MessageID).flowComplete()
				c.freeID(m.MessageID)
			}
		}
	}()
	return output
}

// startOutgoingComms initiates a goroutine to transmit outgoing packets.
// Pass in an open network connection and channels for outbound messages (including those triggered
// directly from incoming comms):
// oboundp - priority outbound packets (written as-is; a DisconnectPacket also closes the connection once written)
// obound - outbound publish packets (the write timeout from c.getWriteTimeOut, if any, is applied to these)
// oboundFromIncomming - packets triggered by an inbound message
// Returns a channel that will receive details of any errors (closed when the goroutine exits)
// The goroutine will only terminate when all input channels are closed
func startOutgoingComms(conn net.Conn,
	c commsFns,
	oboundp <-chan *PacketAndToken,
	obound <-chan *PacketAndToken,
	oboundFromIncomming <-chan *PacketAndToken,
) <-chan error {
	errChan := make(chan error)
	DEBUG.Println(NET, "outgoing started")

	go func() {
		for {
			DEBUG.Println(NET, "outgoing waiting for an outbound message")

			// This goroutine will only exit when all of the input channels we receive on have been closed. This approach is taken to avoid any
			// deadlocks (if the connection goes down there are limited options as to what we can do with anything waiting on us and
			// throwing away the packets seems the best option)
			if oboundp == nil && obound == nil && oboundFromIncomming == nil {
				DEBUG.Println(NET, "outgoing comms stopping")
				close(errChan)
				return
			}

			select {
			case pub, ok := <-obound:
				if !ok {
					obound = nil // closed; a nil channel is never selected again
					continue
				}
				msg := pub.p.(*packets.PublishPacket)

				writeTimeout := c.getWriteTimeOut()
				if writeTimeout > 0 {
					if err := conn.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil {
						ERROR.Println(NET, err)
					}
				}

				if err := msg.Write(conn); err != nil {
					ERROR.Println(NET, "outgoing reporting error", err)
					pub.t.setError(err)
					// report error if it's not due to the connection being closed elsewhere
					if !strings.Contains(err.Error(), closedNetConnErrorText) {
						errChan <- err
					}
					continue
				}

				if writeTimeout > 0 {
					// If we successfully wrote, we don't want the timeout to happen during an idle period
					// so we reset it to infinite.
					if err := conn.SetWriteDeadline(time.Time{}); err != nil {
						ERROR.Println(NET, err)
					}
				}

				// Only QoS 0 publishes are completed here; this routine does not complete the token for other QoS levels
				if msg.Qos == 0 {
					pub.t.flowComplete()
				}
				DEBUG.Println(NET, "obound wrote msg, id:", msg.MessageID)
			case msg, ok := <-oboundp:
				if !ok {
					oboundp = nil
					continue
				}
				DEBUG.Println(NET, "obound priority msg to write, type", reflect.TypeOf(msg.p))
				if err := msg.p.Write(conn); err != nil {
					ERROR.Println(NET, "outgoing reporting error", err)
					if msg.t != nil {
						msg.t.setError(err)
					}
					// Note: unlike the obound case, the error is reported even if the connection was closed elsewhere
					errChan <- err
					continue
				}
				switch msg.p.(type) {
				case *packets.DisconnectPacket:
					// NOTE(review): assumes a DisconnectPacket always carries a *DisconnectToken (the assertion panics otherwise) - confirm against callers
					msg.t.(*DisconnectToken).flowComplete()
					DEBUG.Println(NET, "outbound wrote disconnect, closing connection")
					// As per the MQTT spec "After sending a DISCONNECT Packet the Client MUST close the Network Connection"
					// Closing the connection will cause the goroutines to end in sequence (starting with incoming comms)
					conn.Close()
				}
			case msg, ok := <-oboundFromIncomming: // message triggered by an inbound message (PubrecPacket or PubrelPacket)
				if !ok {
					oboundFromIncomming = nil
					continue
				}
				DEBUG.Println(NET, "obound from incomming msg to write, type", reflect.TypeOf(msg.p))
				if err := msg.p.Write(conn); err != nil {
					ERROR.Println(NET, "outgoing reporting error", err)
					if msg.t != nil {
						msg.t.setError(err)
					}
					errChan <- err
					continue
				}
			}
			// Only reached after a successful write (the error and channel-closed paths above all continue)
			c.UpdateLastSent() // Record that a packet has been sent (for keepalive routine)
		}
	}()
	return errChan
}

// commsFns provides access to the client state (messageids, persisting packets and updating timing);
// the comms routines in this file interact with the client solely through this interface
type commsFns interface {
	getToken(id uint16) tokenCompletor       // Retrieve the token for the specified messageid (if none then a dummy token must be returned)
	freeID(id uint16)                        // Release the specified messageid (clearing out of any persistent store)
	UpdateLastReceived()                     // Must be called whenever a packet is received
	UpdateLastSent()                         // Must be called whenever a packet is successfully sent
	getWriteTimeOut() time.Duration          // Return the writetimeout (or 0 if none)
	persistOutbound(m packets.ControlPacket) // add the packet to the outbound store
	persistInbound(m packets.ControlPacket)  // add the packet to the inbound store
	pingRespReceived()                       // Called when a ping response is received
}

// startComms starts the goroutines that handle all communications over the network connection.
// Messages will be stored (via commsFns) and deleted from the store as necessary.
// It returns two channels:
//  packets.PublishPacket - Receives publish packets read from the network. Closed when the incoming comms routines exit (on shutdown or if network link closed)
//  error - Receives any errors from the incoming or outgoing routines. Closed when all comms routines have shut down
//
// Note: The comms routines monitoring oboundp and obound will not shut down until those channels are both closed. Any messages received between the
// connection being closed and those channels being closed will generate errors (and nothing will be sent). That way the chance of a deadlock is
// minimised.
func startComms(conn net.Conn, // Network connection (must be active)
	c commsFns, // getters and setters to enable us to cleanly interact with client
	inboundFromStore <-chan packets.ControlPacket, // Inbound packets from the persistance store (should be closed relatively soon after startup)
	oboundp <-chan *PacketAndToken,
	obound <-chan *PacketAndToken) (
	<-chan *packets.PublishPacket, // Publishpackages received over the network
	<-chan error, // Any errors (should generally trigger a disconnect)
) {
	// The inbound handler must be able to trigger transmissions (e.g. acknowledgements); those are relayed
	// to the outgoing routine through a dedicated channel.
	fromNetwork := startIncommingComms(conn, c, inboundFromStore)
	replies := make(chan *PacketAndToken)

	sendErrs := startOutgoingComms(conn, c, oboundp, obound, replies)
	DEBUG.Println(NET, "startComms started")

	publishes := make(chan *packets.PublishPacket)
	errs := make(chan error)
	go func() {
		// Keep relaying until both sources have been closed (each is set to nil when that happens)
		for fromNetwork != nil || sendErrs != nil {
			select {
			case in, open := <-fromNetwork:
				switch {
				case !open:
					fromNetwork = nil
					// No more inbound messages can arrive so the two channels fed from them can be closed
					close(replies)
					close(publishes)
				case in.err != nil:
					errs <- in.err
				case in.outbound != nil:
					replies <- in.outbound
				case in.incommingPub != nil:
					publishes <- in.incommingPub
				default:
					ERROR.Println(STR, "startComms received empty incommingComms msg")
				}
			case sendErr, open := <-sendErrs:
				if open {
					errs <- sendErr
				} else {
					sendErrs = nil
				}
			}
		}
		close(errs)
		DEBUG.Println(NET, "startComms gorouting exiting")
	}()
	return publishes, errs
}

// ackFunc returns a function that acknowledges the given publish packet when called:
// QoS 2 - a PUBREC with the packet's message id is sent on oboundP
// QoS 1 - a PUBACK with the packet's message id is persisted and then sent on oboundP
// any other QoS - nothing is sent
// WARNING the function returned must not be called if the comms routine is shutting down or not running
// (it needs outgoing comms in order to send the acknowledgement). Currently this is only called from
// matchAndDispatch which will be shutdown before the comms are
func ackFunc(oboundP chan *PacketAndToken, persist Store, packet *packets.PublishPacket) func() {
	return func() {
		qos := packet.Qos

		if qos == 2 {
			pubrec := packets.NewControlPacket(packets.Pubrec).(*packets.PubrecPacket)
			pubrec.MessageID = packet.MessageID
			DEBUG.Println(NET, "putting pubrec msg on obound")
			oboundP <- &PacketAndToken{p: pubrec, t: nil}
			DEBUG.Println(NET, "done putting pubrec msg on obound")
			return
		}

		if qos == 1 {
			puback := packets.NewControlPacket(packets.Puback).(*packets.PubackPacket)
			puback.MessageID = packet.MessageID
			DEBUG.Println(NET, "putting puback msg on obound")
			persistOutbound(persist, puback)
			oboundP <- &PacketAndToken{p: puback, t: nil}
			DEBUG.Println(NET, "done putting puback msg on obound")
			return
		}

		// QoS 0 (or anything else): there is no need to send an ack packet back
	}
}
